Perceptron Binary Classification Learning Algorithm Tutorial

註:這個 Tutorial 主要還是介紹怎麼使用 FukuML,如果非必要並不會涉入太多演算法或數學式的細節,若大家對機器學習有興趣,還是建議觀看完整的課程。

Perceptron Binary Classification Learning Algorithm(PLA)是最基礎的機器學習算法,主要用在讓機器學習分類,基礎我們會使用在二元分類,再慢慢推廣至多元分類。其核心想法也不難,追根究底就是個知錯能改的演算法,只要有錯就修正分類器,直到不會犯錯為止。PLA 也是一個最基礎的類神經網路的運算神經元,現在很紅的 Deep Learning 的最基礎概念其實就是 PLA,因此了解 PLA 對未來學習機器學習這門課程是很有幫助的。

底下列出幾個 PLA 相關的數學式,方便大家日後學習時查閱:

PLA 假設

$$ h(x) = sign(w^Tx) $$

表示 PLA 對資料每一個維度的權重假設,這個權重向量在式子中以 w 表示,所以我們利用 PLA 學習出最能夠分好類的 w 之後,將 x 丟進去這個 PLA 假設,它就會告訴你分類的結果。

PLA 犯錯

$$ sign(w_t^Tx_{n(t)}) \neq y_{n(t)} $$

表示 PLA 對哪個資料點是預測錯誤的,其實就是對目前的假設 $w_t$ 對 $x_{n(t)}$ 點進行內積再取正負號,如果與 $y_{n(t)}$ 不同,那就代表 PLA 犯錯了。

PLA 修正假設

$$ w_{t+1} = w_t + y_{n(t)}x_{n(t)} $$

表示 PLA 犯錯之後怎麼修正,如果 PLA 猜 +1 但答案是 -1,那就往 $-1(x_{n(t)})$ 對 $w_t$ 做修正;如果 PLA 猜 -1 但答案是 +1,那就往 $+1(x_{n(t)})$ 對 $w_t$ 做修正。

使用 FukuML 的 PLA 做二元分類

接下來讓我們一步一步學習如何使用 FukuML 的 PLA 來做二元分類,首先讓我們將 PLA 引進來:


In [1]:
import FukuML.PLA as pla

然後建構一個 PLA 二元分類物件:


In [2]:
pla_bc = pla.BinaryClassifier()

我希望 FukuML 能儘量簡單易用,因此大家只要牢記 1. 載入訓練資料 -> 2. 設定參數 -> 3. 初始化 -> 4. 訓練 -> 5. 預測 這五個步驟就可以完成機器學習了~

現在第一個步驟要先載入訓練資料,但如果現在要讓大家生出一筆訓練資料應該會有困難,所以 FukuML 每個機器學習演算法都會有一個 Demo 用的內建資料,讓我們先用 Demo 用的內建資料來試試看。


In [3]:
pla_bc.load_train_data()


Out[3]:
(array([[ 1.      ,  0.97681 ,  0.10723 ,  0.64385 ,  0.29556 ],
        [ 1.      ,  0.67194 ,  0.2418  ,  0.83075 ,  0.42741 ],
        [ 1.      ,  0.20619 ,  0.23321 ,  0.81004 ,  0.98691 ],
        ..., 
        [ 1.      ,  0.50468 ,  0.99699 ,  0.75136 ,  0.51681 ],
        [ 1.      ,  0.55852 ,  0.067689,  0.666   ,  0.98482 ],
        [ 1.      ,  0.83188 ,  0.66817 ,  0.23403 ,  0.72472 ]]),
 array([ 1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1., -1.,  1.,  1.,  1.,
        -1., -1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1., -1.,  1.,
        -1., -1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1., -1., -1.,  1., -1.,  1., -1., -1.,  1.,
        -1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,
         1., -1.,  1.,  1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,
        -1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1.,  1.,
         1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1., -1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1., -1.,  1.,
        -1., -1.,  1., -1., -1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1., -1., -1., -1.,  1., -1.,  1., -1.,  1.,
        -1.,  1.,  1., -1., -1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,
         1.,  1.,  1.,  1.,  1., -1., -1., -1., -1.,  1., -1.,  1.,  1.,
        -1.,  1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,
        -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,
         1., -1.,  1.,  1., -1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1.,
         1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1.,
         1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1., -1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1.,
        -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,
         1.,  1.,  1., -1.,  1.,  1., -1., -1., -1.,  1.,  1., -1., -1.,
         1., -1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,
         1.,  1.,  1.,  1., -1.,  1.,  1., -1., -1.,  1., -1.,  1.,  1.,
        -1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1., -1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1.,  1., -1.,
         1.,  1.,  1., -1., -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]))

這樣就載入了 PLA 的 Demo 訓練資料,不信的話大家可以使用 pla_bc.train_Xpla_bc.train_Y 印出來看看:


In [4]:
print(pla_bc.train_X)


[[ 1.        0.97681   0.10723   0.64385   0.29556 ]
 [ 1.        0.67194   0.2418    0.83075   0.42741 ]
 [ 1.        0.20619   0.23321   0.81004   0.98691 ]
 ..., 
 [ 1.        0.50468   0.99699   0.75136   0.51681 ]
 [ 1.        0.55852   0.067689  0.666     0.98482 ]
 [ 1.        0.83188   0.66817   0.23403   0.72472 ]]

訓練資料的特徵資料就存在 train_X 中,矩陣的每一個列就代表一筆資料,然後每一個行就代表一個特徵值,請注意矩陣的第一行都是 1,這是我們演算法自己補上的 $x_0$,並不是原本訓練資料就會有的特徵值,以這個 Demo 資料來說,每筆資料只有 4 個特徵值(feature),像第一筆資料的 4 個特徵值就是 0.97681 0.10723 0.64385 0.29556,演算法將前面補上 $x_0 = 1$,就變成了現在看到的樣子。


In [5]:
print(pla_bc.train_Y)


[ 1.  1.  1.  1.  1.  1. -1.  1. -1. -1.  1.  1.  1. -1. -1.  1.  1.  1.
 -1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1. -1. -1.  1.  1. -1.  1.  1.
 -1. -1.  1. -1. -1.  1. -1.  1.  1.  1. -1. -1.  1.  1.  1.  1.  1.  1.
  1.  1.  1. -1. -1.  1. -1.  1. -1. -1.  1. -1.  1. -1. -1.  1.  1.  1.
 -1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1. -1.  1.  1. -1.  1.  1.  1.
  1.  1.  1.  1. -1.  1. -1.  1.  1. -1.  1.  1.  1.  1. -1.  1.  1.  1.
  1. -1.  1. -1.  1.  1. -1.  1.  1.  1.  1. -1.  1. -1. -1. -1.  1.  1.
  1.  1.  1.  1.  1. -1. -1.  1.  1. -1.  1. -1.  1.  1.  1. -1.  1. -1.
 -1.  1. -1. -1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1. -1.
 -1. -1.  1. -1.  1. -1.  1. -1.  1.  1. -1. -1.  1. -1.  1.  1.  1.  1.
  1.  1.  1.  1. -1.  1.  1. -1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.
  1.  1. -1. -1. -1. -1.  1. -1.  1.  1. -1.  1. -1. -1.  1.  1.  1.  1.
  1.  1.  1. -1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1. -1.
  1. -1.  1.  1. -1.  1.  1.  1.  1. -1.  1. -1.  1.  1.  1.  1.  1. -1.
  1. -1.  1.  1.  1. -1. -1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.
  1.  1.  1.  1. -1.  1.  1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1. -1.  1. -1.  1.  1.  1. -1. -1. -1.  1.  1.  1.  1.
  1.  1.  1.  1.  1. -1.  1.  1.  1. -1.  1.  1. -1. -1. -1.  1.  1. -1.
 -1.  1. -1. -1. -1.  1.  1.  1.  1.  1.  1.  1.  1. -1.  1.  1.  1.  1.
 -1.  1.  1. -1. -1.  1. -1.  1.  1. -1.  1.  1.  1.  1.  1. -1.  1.  1.
  1.  1.  1.  1.  1.  1. -1.  1.  1.  1. -1.  1. -1.  1.  1.  1. -1.  1.
  1.  1. -1. -1.  1.  1.  1.  1.  1.  1.  1.  1.]

然後訓練資料的答案就存在 train_Y 中,也就代表每筆訓練資料的答案是什麼,正分類就是 1,負分類就是 -1。

接下來讓我們進行下一步,設定參數:


In [6]:
pla_bc.set_param(loop_mode='naive_cycle', step_alpha=1)


Out[6]:
('naive_cycle', 1)

PLA 這個演算法我只提供兩個參數可以調,一個是 loop_mode,用來調整 PLA 選擇訓練資料來檢查自己猜錯或猜對的選法,預設是使用 naive_cycle ,會照著訓練資料的順序一個一個檢測,有錯就修正 w。你也可以設成使用 random,這樣 PLA 檢測時就會隨便選擇一個點來檢測,有錯就修正 w。

另一個參數是 step_alpha,用來調整 PLA 每次有錯就修正 w 要修正多少量,原則上設成 1 就可以了。

接下來就可以再進行下一步,初始化:


In [7]:
pla_bc.init_W()


Out[7]:
array([ 0.,  0.,  0.,  0.,  0.])

初始化時,我們可以得到一個最初的權重值 w,通常就是個 0 向量了,但有時我們可以用 Linear Regression 來初始化,加速演算法,之後我們會再介紹,一樣我們將初始化的 w 印出來看看:


In [8]:
print(pla_bc.W)


[ 0.  0.  0.  0.  0.]

好!果然是 0 向量,一切準備就緒,接下來就是重頭戲「訓練」了:


In [9]:
pla_bc.train()


Out[9]:
array([-3.       ,  3.0841436, -1.583081 ,  2.391305 ,  4.5287635])

登登登!訓練完成,我們會得到一個全新的權重值 w,根據 PLA 的運算,這個 w 可以將資料完全分類正確!這就是機器學習神奇的地方!

我們一樣把 PLA 計算出來的 w 印出來看看:


In [10]:
print(pla_bc.W)


[-3.         3.0841436 -1.583081   2.391305   4.5287635]

果然不是個 0 向量了呢!

有了這個 w,我們就可以用它來預測未來的資料,讓我拿一筆測試資料 0.97959 0.40402 0.96303 0.28133 1 來預測看看,前面 4 個值是這筆測試資料的特徵值,後面的 1 代表這筆測試資料的答案,我們來看看預測結果:


In [11]:
test_data = '0.97959 0.40402 0.96303 0.28133 1'
prediction = pla_bc.prediction(test_data)

將預測結果印出來看看:


In [12]:
print(prediction)


{'prediction': 1.0, 'input_data_x': array([ 1.     ,  0.97959,  0.40402,  0.96303,  0.28133]), 'input_data_y': 1.0}

prediction 這個方法會把預測結果回傳成一個 dictionary,預測結果的 key 是 prediction,value 是 1,測試資料的答案也是 1,所以 PLA 正確預測了結果!

假設我們現在要預測的是未知的資料、一些我們還沒有分好類的資料,那我們就是把資料特徵值向量丟進去 prediction 方法,並設定 `mode='future_data',代表是做未知資料的預測,就可以進行預測了,比如丟進去 0.29634 0.4012 0.40266 0.67864 這筆特徵資料試試看:


In [13]:
future_data = '0.29634 0.4012 0.40266 0.67864'
prediction = pla_bc.prediction(future_data, mode='future_data')

將預測結果印出來看看:


In [14]:
print(prediction)


{'prediction': 1.0, 'input_data_x': array([ 1.     ,  0.29634,  0.4012 ,  0.40266,  0.67864]), 'input_data_y': None}

PLA 會忠實的觀察資料給出答案,它認為這筆資料的答案也是 1。(事實上真的是)

當然,如果只是看一、兩筆資料猜對,大家可能會認為這只是運氣好,所以我們必須計算 PLA 在整個訓練資料集及整個測試資料集的預測表現如何。我們提供了很簡易的方法可以計算整體的錯誤率,如果要看 PLA 在整個訓練資料集的預測錯誤率($E_{in}$):


In [15]:
print(pla_bc.calculate_avg_error(pla_bc.train_X, pla_bc.train_Y, pla_bc.W))


0.0

PLA 在訓練資料的預測錯誤率是完美的 0!這是當然的,因為 PLA 在線性可分的資料裡,一定會調整到沒有錯誤為止。

現在我們來看看 PLA 在整個測試資料集的預測錯誤率($E_{out}$),在此之前,我們必須先載入測試資料集,一樣 FukuML 有提供 Demo 版本的測試資料集:


In [16]:
pla_bc.load_test_data()


Out[16]:
(array([[ 1.      ,  0.97959 ,  0.40402 ,  0.96303 ,  0.28133 ],
        [ 1.      ,  0.29634 ,  0.4012  ,  0.40266 ,  0.67864 ],
        [ 1.      ,  0.34922 ,  0.99751 ,  0.23234 ,  0.52115 ],
        [ 1.      ,  0.65637 ,  0.7181  ,  0.72843 ,  0.93113 ],
        [ 1.      ,  0.079695,  0.57218 ,  0.70591 ,  0.33812 ],
        [ 1.      ,  0.71206 ,  0.51569 ,  0.18168 ,  0.5557  ],
        [ 1.      ,  0.17528 ,  0.2625  ,  0.8306  ,  0.029669],
        [ 1.      ,  0.93895 ,  0.93941 ,  0.72496 ,  0.95655 ],
        [ 1.      ,  0.046136,  0.94413 ,  0.038311,  0.26812 ],
        [ 1.      ,  0.072491,  0.2242  ,  0.62592 ,  0.67238 ]]),
 array([ 1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1.]))

載入測試資料之後,我們就可以計算 PLA 在測試資料集的預測錯誤率($E_{out}$):


In [17]:
print(pla_bc.calculate_test_data_avg_error())


0.0

PLA 在測試資料的預測錯誤率也是完美的 0,當然這某種程度是因為我們的 Demo 資料有設計過,不過理論上測試資料的預測錯誤率應該不會和訓練資料的預測錯誤率差太多,只要實驗過程是一個客觀的過程、沒有經過人為的污染,機器學習的演算法的確可以做到正確的預測。

以上,你大概已經學會使用 FukuML 提供的 PLA 做訓練,然後使用訓練完成的 w 來進行未知資料的預測了,真的五個步驟就可以做完了!很簡單吧!

使用自己的訓練資料集和測試資料集

前面的教學我們是使用 FukuML 所提供的訓練資料集和測試資料集,真實情況你當然使用自己的資料,那要怎麼做呢?FukuML 提供了很簡易的方法可以讓大家載入自己的資料:

your_training_data_file = '/path/to/your/training_data/file'
pla_bc.load_train_data(your_training_data_file)

your_testing_data_file = '/path/to/your/testing_data/file'
pla_bc.load_test_data(your_testing_data_file)

就是這麼簡單,讓我們來實際演示一下:


In [18]:
pla_bc = pla.BinaryClassifier()
pla_bc.load_train_data('/Users/fukuball/Projects/fuku-ml/FukuML/dataset/linear_separable_train.dat')


Out[18]:
(array([[ 1.        , -0.49475104,  1.60851023],
        [ 1.        ,  0.99350955,  2.53942025],
        [ 1.        ,  0.67365802,  2.41859411],
        [ 1.        , -1.91676615,  0.48923093],
        [ 1.        , -0.80964166,  1.26206511],
        [ 1.        , -0.45285374,  1.82885284],
        [ 1.        ,  0.27463815,  2.08049683],
        [ 1.        ,  0.89694355,  3.7834262 ],
        [ 1.        , -1.72520564,  0.87640485],
        [ 1.        ,  0.7349451 ,  3.39882197],
        [ 1.        , -1.02461018,  1.44258081],
        [ 1.        , -0.60392455,  0.98807458],
        [ 1.        ,  0.08098387,  2.15878467],
        [ 1.        ,  0.48213089,  2.18476304],
        [ 1.        ,  0.74123261,  3.22706092],
        [ 1.        , -0.57649605,  0.27757466],
        [ 1.        , -1.60301663,  0.85311484],
        [ 1.        , -1.90040634,  1.14021401],
        [ 1.        ,  0.7943513 ,  2.68559323],
        [ 1.        ,  0.15398661,  2.61447653],
        [ 1.        ,  2.4192871 ,  4.18943591],
        [ 1.        ,  0.13016586,  2.53128795],
        [ 1.        , -1.00057111,  1.2998211 ],
        [ 1.        , -2.24935866, -0.51829791],
        [ 1.        , -0.11745011,  2.36365622],
        [ 1.        , -0.18131864,  1.90732415],
        [ 1.        , -1.0669876 ,  1.84490598],
        [ 1.        , -0.41819858,  1.20384123],
        [ 1.        , -1.27557363,  1.58879675],
        [ 1.        , -0.48455613,  1.56688674],
        [ 1.        ,  1.76857878,  2.70393626],
        [ 1.        ,  0.6178306 ,  1.41965757],
        [ 1.        ,  0.24021005,  2.07796794],
        [ 1.        , -0.40745049,  1.12846498],
        [ 1.        , -0.5450063 ,  1.64924578],
        [ 1.        , -0.89149772,  1.29851015],
        [ 1.        ,  1.12855231,  1.96797717],
        [ 1.        , -0.62563244,  1.87988573],
        [ 1.        , -0.91508504,  1.62532636],
        [ 1.        , -1.21008395,  0.41751392],
        [ 1.        ,  1.0369232 ,  3.32224131],
        [ 1.        ,  0.82678315,  3.36840655],
        [ 1.        ,  0.00522133,  2.57820823],
        [ 1.        , -1.06147755,  1.06473163],
        [ 1.        , -0.45700467,  2.00276916],
        [ 1.        ,  0.13487671,  0.65962212],
        [ 1.        ,  1.22494293,  2.4905672 ],
        [ 1.        ,  0.82587401,  1.6469229 ],
        [ 1.        , -0.46393125,  2.81795857],
        [ 1.        , -2.36851079,  1.01187775],
        [ 1.        ,  0.12587105,  2.55995705],
        [ 1.        , -0.35712397,  1.88322814],
        [ 1.        ,  0.68857731,  2.45378334],
        [ 1.        , -1.11846239,  1.7060288 ],
        [ 1.        , -1.73549484,  1.16778056],
        [ 1.        ,  0.18491969,  1.6888773 ],
        [ 1.        ,  1.10350087,  2.55392247],
        [ 1.        , -0.44246031,  1.49684599],
        [ 1.        , -0.22148107,  2.66175094],
        [ 1.        ,  0.30829778,  2.25791677],
        [ 1.        , -0.29287034,  2.04485062],
        [ 1.        , -0.44357665,  1.58064718],
        [ 1.        ,  0.01694366,  1.60172119],
        [ 1.        , -0.35169509,  2.20195385],
        [ 1.        ,  0.55527319,  1.84184212],
        [ 1.        , -0.59067181,  1.19101348],
        [ 1.        ,  0.21534601,  2.95975298],
        [ 1.        ,  0.79769729,  2.79259136],
        [ 1.        ,  0.41191044,  1.86899517],
        [ 1.        , -1.39417234,  1.15327164],
        [ 1.        ,  0.71641377,  2.87832566],
        [ 1.        ,  0.44264983,  3.19840287],
        [ 1.        , -1.11935978,  1.0214965 ],
        [ 1.        ,  0.25788802,  3.2897688 ],
        [ 1.        , -0.33296609,  0.88930833],
        [ 1.        ,  2.07653897,  3.77544829],
        [ 1.        , -0.35754265,  0.91029036],
        [ 1.        ,  0.38998221,  2.7169355 ],
        [ 1.        ,  0.88980695,  2.23294531],
        [ 1.        , -0.30374769,  2.16560662],
        [ 1.        ,  0.46858362,  2.22595082],
        [ 1.        ,  1.5953943 ,  3.86456874],
        [ 1.        ,  0.80516593,  1.14755445],
        [ 1.        , -1.07078848,  1.07630871],
        [ 1.        , -0.61184666,  1.08231727],
        [ 1.        , -1.22082978,  1.441157  ],
        [ 1.        ,  0.42133054,  2.00527312],
        [ 1.        , -1.15371133,  0.39545553],
        [ 1.        , -1.16529981,  0.55726593],
        [ 1.        , -0.0753288 ,  2.65117295],
        [ 1.        ,  1.6801046 , -0.08598257],
        [ 1.        ,  1.8497788 , -0.01692729],
        [ 1.        ,  1.85056814,  0.08177583],
        [ 1.        ,  2.80034537,  0.81539772],
        [ 1.        ,  0.40956132,  0.55189673],
        [ 1.        ,  1.31484126, -0.82314199],
        [ 1.        ,  0.74578981, -0.63417363],
        [ 1.        ,  1.52094099,  0.07809939],
        [ 1.        ,  2.07568652, -0.1150392 ],
        [ 1.        ,  0.99162865, -0.53539644],
        [ 1.        ,  0.7038055 , -0.19221186],
        [ 1.        ,  2.06231887, -0.03769865],
        [ 1.        ,  2.5799191 ,  0.07069567],
        [ 1.        ,  0.87544838, -0.5637692 ],
        [ 1.        ,  2.11936394, -0.12933603],
        [ 1.        ,  1.72361776, -0.83946543],
        [ 1.        ,  1.00174566, -0.38122767],
        [ 1.        ,  2.67256573,  1.39825995],
        [ 1.        ,  1.95670692, -0.76033255],
        [ 1.        , -0.12377975, -1.66555773],
        [ 1.        ,  3.49754388,  0.86099572],
        [ 1.        ,  2.30620253,  1.41859895],
        [ 1.        ,  2.55608362,  0.42400762],
        [ 1.        ,  2.48991027,  0.76553132],
        [ 1.        ,  2.55940671, -0.57984346],
        [ 1.        ,  2.32904612, -0.51945896],
        [ 1.        ,  1.7353993 , -0.75519976],
        [ 1.        ,  2.51829003,  0.37786517],
        [ 1.        ,  1.8706277 , -0.93869733],
        [ 1.        ,  0.22236542, -2.44483319],
        [ 1.        , -0.09038213, -1.79941358],
        [ 1.        ,  2.70225343,  0.94731516],
        [ 1.        ,  1.88566698,  0.2798723 ],
        [ 1.        ,  1.58910203, -0.70947294],
        [ 1.        ,  3.0973127 ,  0.92156856],
        [ 1.        ,  2.69809437,  0.38175539],
        [ 1.        ,  3.3289721 ,  0.41484516],
        [ 1.        ,  1.87232143, -0.61040703],
        [ 1.        ,  2.33296818,  0.02254511],
        [ 1.        ,  1.81407758, -0.17053957],
        [ 1.        ,  2.96737955,  1.44181063],
        [ 1.        ,  2.9380551 ,  0.47943516],
        [ 1.        ,  0.73973305, -1.28050721],
        [ 1.        ,  2.08422916, -1.39634791],
        [ 1.        ,  1.66061566, -0.98458495],
        [ 1.        ,  1.98635728, -0.28509211],
        [ 1.        ,  2.56435931, -0.47953988],
        [ 1.        ,  2.20241294,  0.39511807],
        [ 1.        ,  3.87224268,  0.76225007],
        [ 1.        ,  1.64152869,  0.42732398],
        [ 1.        ,  2.01228488,  0.83947942],
        [ 1.        ,  1.12248906, -0.86437473],
        [ 1.        ,  0.92242964,  0.68317263],
        [ 1.        ,  1.60673796,  0.18415559],
        [ 1.        ,  1.50243849,  0.01270292],
        [ 1.        ,  2.85000032,  0.26811154],
        [ 1.        ,  1.11113213, -1.08291552],
        [ 1.        ,  1.40913806, -0.73020544],
        [ 1.        ,  1.73676161,  0.22610285],
        [ 1.        ,  2.55634878,  0.73043033],
        [ 1.        ,  1.68015105, -0.51196788],
        [ 1.        ,  2.07339426, -0.69056912],
        [ 1.        ,  2.51559063,  0.28962156],
        [ 1.        ,  1.91898721, -0.38112982],
        [ 1.        ,  2.84276804,  1.00034359],
        [ 1.        ,  2.48044735,  0.97198591],
        [ 1.        ,  1.00833589, -1.14853349],
        [ 1.        ,  1.33359442, -0.28894411],
        [ 1.        ,  1.77539826, -1.30014904],
        [ 1.        ,  0.67101677, -1.91851934],
        [ 1.        ,  3.17326187,  1.70334529],
        [ 1.        ,  2.08009842, -0.74647001],
        [ 1.        ,  2.29719679, -0.16319915],
        [ 1.        ,  1.83769012, -0.19705278],
        [ 1.        ,  1.80368554, -0.12629935],
        [ 1.        ,  1.56859964, -0.2302815 ],
        [ 1.        ,  1.47931541, -0.54655241],
        [ 1.        ,  1.9768354 , -0.26445185],
        [ 1.        ,  2.50410495,  1.52574828],
        [ 1.        ,  2.35432514, -0.47573114],
        [ 1.        ,  1.07167578,  0.15158119],
        [ 1.        ,  1.11479351, -1.433885  ],
        [ 1.        ,  1.12664849,  0.11512314],
        [ 1.        ,  2.08937432, -0.48560459],
        [ 1.        ,  0.9236267 , -0.89739827],
        [ 1.        ,  2.22914995,  0.08347266],
        [ 1.        ,  1.87811234, -0.22916618],
        [ 1.        ,  3.32948946,  0.77274808],
        [ 1.        ,  2.60176922,  0.08437287],
        [ 1.        ,  1.98635292,  0.81549389]]),
 array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.]))

In [19]:
pla_bc.load_test_data('/Users/fukuball/Projects/fuku-ml/FukuML/dataset/linear_separable_test.dat')


Out[19]:
(array([[  1.00000000e+00,   3.12179041e-01,   3.26300582e+00],
        [  1.00000000e+00,  -7.88545922e-01,   1.84177454e+00],
        [  1.00000000e+00,   3.45018856e-01,   2.02971487e+00],
        [  1.00000000e+00,  -5.90936663e-02,   2.06095580e+00],
        [  1.00000000e+00,  -6.04229133e-01,   1.89545186e+00],
        [  1.00000000e+00,   2.92639463e-01,   2.21847534e+00],
        [  1.00000000e+00,   1.37291076e+00,   3.10397301e+00],
        [  1.00000000e+00,  -8.55850926e-01,   7.43968659e-01],
        [  1.00000000e+00,  -2.33116362e-04,   1.45262917e+00],
        [  1.00000000e+00,   5.63747692e-01,   2.65759454e+00],
        [  1.00000000e+00,   3.27474170e+00,   5.16394190e-01],
        [  1.00000000e+00,   1.00982446e+00,  -9.92472127e-01],
        [  1.00000000e+00,   1.63602318e+00,  -7.66844250e-01],
        [  1.00000000e+00,   2.81507689e+00,   2.63441093e-01],
        [  1.00000000e+00,   1.83736479e+00,  -8.03493918e-01],
        [  1.00000000e+00,   2.26025418e+00,   4.02606276e-01],
        [  1.00000000e+00,   2.18689341e+00,   7.86296427e-01],
        [  1.00000000e+00,   1.34179804e+00,   1.26613719e-04],
        [  1.00000000e+00,   2.75190511e+00,  -3.67967235e-01],
        [  1.00000000e+00,   2.92122264e+00,   1.35934066e-01]]),
 array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1.]))

看吧,都順利載入資料了,接下來的問題只剩下資料集的格式是怎麼樣,這個可以直接看 FukuML 提供的資料集一窺究竟:

https://github.com/fukuball/fuku-ml/blob/master/FukuML/dataset/pla_binary_train.dat

其實格式真的很簡單,就是將每筆資料的特徵值用空格隔開,然後放成一橫行,然後將這筆資料的答案用空格隔開放在最後,答案是正分類就是 1,負分類就是 -1,這樣就完成了。

所以比如你想做銀行核卡預測,然後審核的特徵是年薪、年齡、性別,那假設小明年薪 100W、年齡 30、性別男性且通過核卡了,那這筆資料就是:

100 30 1 1

假設小華年薪 20W、年齡 25、性別男性,沒有通過核卡,這筆資料就是:

20 25 1 -1

假設小美年薪 30W、年齡 24、性別女性,有過核卡,這筆資料就是:

30 24 0 1

以此類推,簡簡單單、輕輕鬆鬆,大家就可以使用自己的資料來玩玩看機器學習囉~

使用二維資料來幫助理解

其實 PLA 分類演算法計算出來的 w 就是去找出一條可以將資料點完美分開的線,書本上的範例可能會使用二維的資料集並畫成圖示呈現給大家看,但在真實世界中,我們的資料通常不會只是二維的,這樣找出來的 w 就會是一個在高維空間將資料完美分類的超平面,我們很難在平面上呈現這樣的結果,因此還是請大家多去從抽象化的高維空間去思考機器學習的過程,不要遷就於圖示。不過如果你剛接觸機器學習,使用二維資料來慢慢理解機器學習演算法也是一個不錯的學習方法,我這邊稍微展示一下如何印出二維資料點及機器學習訓練出來的 w。

載入資料點時,我們就可以在平面上印出所有的資料點,正分類印成紅色的,負分類印成藍色的:


In [20]:
%matplotlib inline

import FukuML.PLA as pla
import matplotlib.pyplot as plt

pla_bc = pla.BinaryClassifier()
pla_bc.load_train_data('/Users/fukuball/Projects/fuku-ml/FukuML/dataset/linear_separable_train.dat')

for idx, val in enumerate(pla_bc.train_Y):
    if val==1:
        plt.plot(pla_bc.train_X[idx,1], pla_bc.train_X[idx,2], "ro")
    else:
        plt.plot(pla_bc.train_X[idx,1], pla_bc.train_X[idx,2], "bo")
        
plt.axis("tight")
plt.show()


機器訓練完之後,我們可以得到 w,這時只要使用 $w_2*x_2+w_1*x_1+w_0*x_0=0$ 的線性方程式找出斜率,就可以在平面上畫出 w:


In [21]:
pla_bc.set_param(loop_mode='naive_cycle', step_alpha=1)
pla_bc.init_W()
pla_bc.train()

for idx, val in enumerate(pla_bc.train_Y):
    if val==1:
        plt.plot(pla_bc.train_X[idx,1], pla_bc.train_X[idx,2], "ro")
    else:
        plt.plot(pla_bc.train_X[idx,1], pla_bc.train_X[idx,2], "bo")

a0 = -4;
a1 = (-pla_bc.W[0]-pla_bc.W[1]*a0)/pla_bc.W[2]
b0 = 4;
b1 = (-pla_bc.W[0]-pla_bc.W[1]*b0)/pla_bc.W[2]

plt.plot([a0, b0], [a1, b1], "k")

plt.axis("tight")
plt.show()


這樣就可以畫成圖示了,但記得圖示只是用來幫助理解,在使用或學習機器學習這門課時,千萬不能被圖示牽著走喔!